//////////////////////////////////////////////
// main.cpp
//
//////////////////////////////////////////////

/// Includes ---------------------------------

// Local
#include "PastCode.h"

// nkGraphics
#include <NilkinsGraphics/Buffers/Buffer.h>
#include <NilkinsGraphics/Buffers/BufferManager.h>

#include <NilkinsGraphics/Compositors/Compositor.h>
#include <NilkinsGraphics/Compositors/CompositorManager.h>
#include <NilkinsGraphics/Compositors/CompositorNode.h>
#include <NilkinsGraphics/Compositors/TargetOperations.h>

#include <NilkinsGraphics/Passes/ClearTargetsPass.h>
#include <NilkinsGraphics/Passes/PostProcessPass.h>
#include <NilkinsGraphics/Passes/RaytracingPass.h>

#include <NilkinsGraphics/RenderContexts/RenderContext.h>
#include <NilkinsGraphics/RenderContexts/RenderContextDescriptor.h>
#include <NilkinsGraphics/RenderContexts/RenderContextManager.h>

#include <NilkinsGraphics/Samplers/Sampler.h>
#include <NilkinsGraphics/Samplers/SamplerManager.h>

#include <NilkinsGraphics/Textures/Texture.h>
#include <NilkinsGraphics/Textures/TextureManager.h>

#include <NilkinsGraphics/Renderers/Renderer.h>

#include <NilkinsGraphics/System.h>

/// Internals : Shaders ----------------------

void prepareRaygenMissProgram ()
{
	nkGraphics::Program* raygenMissProgram = nkGraphics::ProgramManager::getInstance()->createOrRetrieve("raygenMiss") ;
	nkGraphics::ProgramSourcesHolder sources ;

	// This program will only use compute stage
	sources.setRaytracingMemory
	(
		R"eos(
			struct RayPayload
			{
				float4 color ;
				float depth ;
			} ;

			cbuffer PassConstants : register(b0)
			{
				uint4 texInfos ;
				float4 camPos ;
				matrix invView ;
				matrix invProj ;
			}

			RaytracingAccelerationStructure scene : register(t0) ;
			RWStructuredBuffer<float4> output : register(u0) ;

			// This time, we will add the environment map and a sampler to be able to use it in one or more stages
			TextureCube envMap : register(t1) ;
			SamplerState envSampler : register(s0) ;

			[shader("raygeneration")]
			void raygen ()
			{
				// Compute Ray's origin, as simple units
				float2 dispatchIndex = DispatchRaysIndex().xy ;
				float2 pixCenter = dispatchIndex.xy + 0.5 ;
				float2 uvs = pixCenter / texInfos.xy * 2.0 - 1.0 ;
				uvs.y = -uvs.y ;
				
				float3 pixelOrigin = camPos.xyz ;
				float4 pixelDir = mul(invView, mul(invProj, float4(uvs, 0, 1))) ;
				pixelDir.xyz /= pixelDir.w ;
				float3 pixelDirVec3 = normalize(pixelDir.xyz - pixelOrigin) ;

				// Trace the ray
				RayDesc ray ;
				ray.Origin = pixelOrigin ;
				ray.Direction = pixelDirVec3 ;
				ray.TMin = 0.001 ;
				ray.TMax = 100.0 ;
				RayPayload payload = {float4(1, 1, 1, 1), 1} ;

				// Simple tracing
				TraceRay(scene, RAY_FLAG_NONE, ~0, 0, 1, 0, ray, payload) ;

				// And writing
				uint index = dispatchIndex.y * texInfos.x + dispatchIndex.x ;
				
				output[index] = float4(payload.color.xyz, 1) ;
			}

			[shader("miss")]
			void miss (inout RayPayload payload)
			{
				// When we miss a geometry, we will sample the environment map to make a nice background
				// Notice that we need to specify the mip level to sample : consider raytracing programs as compute programs, Sample is unavailable
				payload.color = envMap.SampleLevel(envSampler, normalize(WorldRayDirection()), 0) ;
			}
		)eos"
	) ;

	raygenMissProgram->setFromMemory(sources) ;
	raygenMissProgram->load() ;
}

void prepareRaygenMissShader ()
{
	// Prepare the shader used by the pass for raygen miss purposes
	nkGraphics::Shader* raygenMissShader = nkGraphics::ShaderManager::getInstance()->createOrRetrieve("raygenMiss") ;
	nkGraphics::Program* raygenMissProgram = nkGraphics::ProgramManager::getInstance()->get("raygenMiss") ;

	raygenMissShader->setProgram(raygenMissProgram) ;

	// Constant Buffer needs many information
	nkGraphics::ConstantBuffer* cBuffer = raygenMissShader->addConstantBuffer(0) ;

	// We will find it through our offscreen texture
	nkGraphics::ShaderPassMemorySlot* slot = cBuffer->addPassMemorySlot() ;
	slot->setAsTargetSize() ;

	slot = cBuffer->addPassMemorySlot() ;
	slot->setAsCameraPosition() ;

	slot = cBuffer->addPassMemorySlot() ;
	slot->setAsViewMatrixInv() ;

	slot = cBuffer->addPassMemorySlot() ;
	slot->setAsProjectionMatrixInv() ;

	// The scene acceleration structure is accessible through the render queue
	nkGraphics::RenderQueue* rq = nkGraphics::RenderQueueManager::getInstance()->get(nkGraphics::RenderQueueManager::DEFAULT_RENDER_QUEUE) ;
	raygenMissShader->addTexture(rq->getAccelerationStructureBuffer(), 0) ;

	// Texture and sampler
	nkGraphics::Texture* tex = nkGraphics::TextureManager::getInstance()->get("tex") ;
	raygenMissShader->addTexture(tex, 1) ;

	nkGraphics::Sampler* sampler = nkGraphics::SamplerManager::getInstance()->get("sampler") ;
	raygenMissShader->addSampler(sampler, 0) ;

	nkGraphics::Buffer* buffer = nkGraphics::BufferManager::getInstance()->get("raytracedBuffer") ;
	raygenMissShader->addUavBuffer(buffer, 0) ;

	// Finalize loading
	raygenMissShader->load() ;
}

void prepareReflectionHitProgram ()
{
	nkGraphics::Program* hitProgram = nkGraphics::ProgramManager::getInstance()->createOrRetrieve("reflectionHit") ;
	nkGraphics::ProgramSourcesHolder sources ;

	// HLSL library holding a single closest hit entry point ("closestHit") :
	// - interpolates the hit triangle normal from the mesh vertex data (_vertexData, bound by nkGraphics through its name)
	// - reflects the incoming ray around it and traces again while payload.depth < 6
	// - once that budget is spent, samples the environment map along the reflected direction
	// The 6 has to stay in line with setMaxTraceRecursionDepth in prepareCompositor (payload.depth starts at 1 in raygen)
	sources.setRaytracingMemory
	(
		R"eos(
			struct RayPayload
			{
				float4 color ;
				float depth ;
			} ;

			// Vertex data composition
			// Needs to be aligned on the mesh layout itself
			struct VertexData
			{
				float3 position ;
				float2 uvs ;
				float3 normal ;
			} ;

			RaytracingAccelerationStructure scene : register(t0) ;
			TextureCube envMap : register(t1) ;

			// Mesh data, naming is important as it is how nkGraphics can get the slots back
			StructuredBuffer<VertexData> _vertexData : register(t2) ;

			SamplerState envSampler : register(s0) ;

			[shader("closesthit")]
			void closestHit (inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr)
			{
				// Compute barycentrics
				float3 bary = float3(1.0 - attr.barycentrics.x - attr.barycentrics.y, attr.barycentrics.x, attr.barycentrics.y) ;
				uint primitiveIndex = PrimitiveIndex() ;

				// Find back all points constituting triangle
				float3 a = _vertexData[primitiveIndex * 3 + 0].normal ;
				float3 b = _vertexData[primitiveIndex * 3 + 1].normal ;
				float3 c = _vertexData[primitiveIndex * 3 + 2].normal ;

				// Recompute from barycenters in triangle
				float3 hitPosition = WorldRayOrigin() + WorldRayDirection() * RayTCurrent() ;
				float3 foundNormal = normalize(a * bary.x + b * bary.y + c * bary.z) ;
				float3 worldDir = normalize(WorldRayDirection()) ;

				// Prepare new ray to fire
				RayDesc ray ;
				ray.Origin = hitPosition ;
				ray.Direction = normalize(reflect(worldDir, foundNormal)) ;
				ray.TMin = 0.001 ;
				ray.TMax = 100.0 ;

				// Update payload through a new tracing
				if (payload.depth < 6)
				{
					payload.depth += 1.0 ;
					TraceRay(scene, RAY_FLAG_NONE, ~0, 0, 1, 0, ray, payload) ;
				}
				else
				{
					// Recursion budget spent : approximate what the reflected ray would see by sampling the environment along it
					payload.color = envMap.SampleLevel(envSampler, ray.Direction, 0) ;
				}
			}
		)eos"
	) ;

	// Feed the sources to the program registered as "reflectionHit", then load it
	hitProgram->setFromMemory(sources) ;
	hitProgram->load() ;
}

void prepareReflectionHitShader ()
{
	// Prepare the shader used by the pass for raygen miss purposes
	nkGraphics::Shader* hitShader = nkGraphics::ShaderManager::getInstance()->createOrRetrieve("reflectionHit") ;
	nkGraphics::Program* hitProgram = nkGraphics::ProgramManager::getInstance()->get("reflectionHit") ;

	hitShader->setProgram(hitProgram) ;

	// The scene acceleration structure is accessible through the render queue
	nkGraphics::RenderQueue* rq = nkGraphics::RenderQueueManager::getInstance()->get(nkGraphics::RenderQueueManager::DEFAULT_RENDER_QUEUE) ;
	hitShader->addTexture(rq->getAccelerationStructureBuffer(), 0) ;

	// Give it related texture info
	nkGraphics::Texture* envMap = nkGraphics::TextureManager::getInstance()->get("tex") ;
	hitShader->addTexture(envMap, 1) ;

	nkGraphics::Sampler* envSampler = nkGraphics::SamplerManager::getInstance()->get("sampler") ;
	hitShader->addSampler(envSampler, 0) ;

	// Finalize loading
	hitShader->load() ;
}

/// Internals : Compositor -------------------

nkGraphics::Compositor* prepareCompositor ()
{
	// Builds the "compositor" : a single node, rendering into the back buffer through three passes
	// (clear, raytrace into the "raytracedBuffer" UAV, then copy that buffer onto the target).
	// Returns the compositor so the caller can attach it to a render context.

	// Shaders the passes rely on, prepared beforehand
	nkGraphics::Shader* raygenMissShader = nkGraphics::ShaderManager::getInstance()->get("raygenMiss") ;
	nkGraphics::Shader* bufferCopyShader = nkGraphics::ShaderManager::getInstance()->get("bufferCopy") ;

	// The raytracing pass traces against the default render queue
	nkGraphics::RenderQueue* rq = nkGraphics::RenderQueueManager::getInstance()->get(nkGraphics::RenderQueueManager::DEFAULT_RENDER_QUEUE) ;

	// Get the compositor
	nkGraphics::Compositor* compositor = nkGraphics::CompositorManager::getInstance()->createOrRetrieve("compositor") ;
	nkGraphics::CompositorNode* node = compositor->addNode() ;

	// Single operation rendering straight into the back buffer, using the context's depth buffer (no need for another)
	nkGraphics::TargetOperations* targetOp = node->addOperations() ;
	targetOp->setToBackBuffer(true) ;
	targetOp->setToChainDepthBuffer(true) ;

	// Start from cleared targets
	targetOp->addClearTargetsPass() ;

	// Raytrace the scene, one dispatch per pixel
	// 800x600 matches the render context created in main
	nkGraphics::RaytracingPass* rtPass = targetOp->addRaytracingPass() ;
	rtPass->setRq(rq) ;
	rtPass->setRaygenMissShader(raygenMissShader) ;
	rtPass->setWidth(800) ;
	rtPass->setHeight(600) ;
	// Must stay in line with the "reflectionHit" guard (payload.depth < 6, depth starting at 1 in raygen)
	rtPass->setMaxTraceRecursionDepth(6) ;

	// Copy the raytraced buffer onto the target
	nkGraphics::PostProcessPass* postProcessPass = targetOp->addPostProcessPass() ;
	postProcessPass->setShader(bufferCopyShader) ;

	return compositor ;
}

/// Internals : Scene ------------------------

void prepareRaytracingInScene ()
{
	// Do like last tutorial
	nkGraphics::RenderQueue* rq = nkGraphics::RenderQueueManager::getInstance()->get(nkGraphics::RenderQueueManager::DEFAULT_RENDER_QUEUE) ;
	rq->setRaytraced(true) ;

	nkGraphics::Shader* shader = nkGraphics::ShaderManager::getInstance()->get("reflectionHit") ;
	nkGraphics::Entity* ent = rq->getEntity(0) ;
	ent->getRenderInfo().getSlots()[0]->getLods()[0]->setRaytracingShader(shader) ;

	// Add a second sphere to the scene for reflections this time
	shader = nkGraphics::ShaderManager::getInstance()->get("reflectionHit") ;
	nkGraphics::Mesh* sphere = nkGraphics::MeshManager::getInstance()->get("Mesh") ;

	ent = rq->addEntity() ;
	ent->setRenderInfo(nkGraphics::EntityRenderInfo(sphere, nullptr)) ;
	ent->getRenderInfo().getSlots()[0]->getLods()[0]->setRaytracingShader(shader) ;

	nkGraphics::Node* node = nkGraphics::NodeManager::getInstance()->createOrRetrieve("secondNode") ;
	node->setPositionAbsolute(nkMaths::Vector(0.f, 5.f, 0.f)) ;
	ent->setParentNode(node) ;
}

/// Function ---------------------------------

int main ()
{
	// Prepare for logging : route nkGraphics logs to the console
	std::unique_ptr<nkLog::Logger> logger = std::make_unique<nkLog::ConsoleLogger>() ;
	nkGraphics::LogManager::getInstance()->setReceiver(logger.get()) ;

	// For easiness, resources are resolved from the Data folder
	nkResources::ResourceManager::getInstance()->setWorkingPath("Data") ;

	// Initialize the graphics system
	if (!nkGraphics::System::getInstance()->initialize())
		return -1 ;

	// Query now what the hardware is capable of
	nkGraphics::RendererSupportInfo supportInfo = nkGraphics::System::getInstance()->getRenderer()->getRendererSupportInfo() ;

	if (!supportInfo._supportsRaytracing)
	{
		logger->log("Current hardware does not support raytracing, this tutorial cannot be run on this machine.", "RtxTutorial") ;

		// The system was initialized successfully : release it before leaving
		nkGraphics::System::getInstance()->kill() ;

		system("pause") ;
		return 0 ;
	}

	// Basic resource preparations
	baseInit() ;
	prepareRaytracedBuffer() ;

	// Raygen and miss shaders
	prepareRaygenMissProgram() ;
	prepareRaygenMissShader() ;

	// Hit shader
	prepareReflectionHitProgram() ;
	prepareReflectionHitShader() ;

	prepareRaytracingInScene() ;

	// Prepare the composition once everything is ready
	nkGraphics::Compositor* compositor = prepareCompositor() ;

	// Create the windowed context and use the compositor for it
	nkGraphics::RenderContext* context = nkGraphics::RenderContextManager::getInstance()->createRenderContext(nkGraphics::RenderContextDescriptor(800, 600, false, true)) ;

	// Guard against a failed creation before dereferencing the context
	if (!context)
	{
		logger->log("Render context could not be created, aborting.", "RtxTutorial") ;
		nkGraphics::System::getInstance()->kill() ;

		system("pause") ;
		return -1 ;
	}

	context->setCompositor(compositor) ;

	// And trigger the rendering
	renderLoop(context) ;

	// Clean exit
	nkGraphics::System::getInstance()->kill() ;

	system("pause") ;

	return 0 ;
}